{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "27f719be",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "import torch\n",
    "import numpy as np\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem import AllChem\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "from edm_qm9_utils.property_prediction import main_qm9_prop\n",
    "import utils_yy.utils as utils_yy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "474adc98",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(condition='alpha', log_dir=None, classifier_dir=None, data_dir=None, max_samples=100000):\n",
    "    '''Score property-conditioned samples with a pretrained property classifier.\n",
    "\n",
    "    Loads generated conformers and their target condition values from\n",
    "    ``log_dir``, predicts the property of every kept molecule with the\n",
    "    classifier, and prints normalization statistics, a random baseline,\n",
    "    the number of scored molecules and the mean absolute error (MAE)\n",
    "    between predictions and targets. Nothing is returned.\n",
    "\n",
    "    Args:\n",
    "        condition: Key of the property array in ``valid.npz`` / ``train.npz``;\n",
    "            also the suffix of the default log and classifier directories.\n",
    "        log_dir: Directory with ``sample_conformer.pt`` and ``condition.pt``.\n",
    "            ``None`` keeps the original hard-coded location.\n",
    "        classifier_dir: Directory with ``args.pickle`` and\n",
    "            ``best_checkpoint.npy``. ``None`` keeps the original location.\n",
    "        data_dir: Directory with ``valid.npz`` and ``train.npz``. ``None``\n",
    "            keeps the original location.\n",
    "        max_samples: Only the first ``max_samples`` molecules are scored;\n",
    "            ``None`` scores all of them.\n",
    "    '''\n",
    "    qm9_atom_list = ['H', 'C', 'O', 'N', 'F']\n",
    "    atom_encoder = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}\n",
    "    max_atoms = 29  # molecules with more atoms than this are skipped\n",
    "\n",
    "    if log_dir is None:\n",
    "        log_dir = '../AE_geom_uncond_weights_and_data/job21_latent_ddpm_qm9_spatial_graph_condition_' + condition\n",
    "    if classifier_dir is None:\n",
    "        classifier_dir = '../e3_diffusion_for_molecules/qm9/property_prediction/checkpoints/QM9/Property_Classifiers/exp_class_' + condition\n",
    "    if data_dir is None:\n",
    "        data_dir = '../e3_diffusion_for_molecules/qm9/latent_diffusion/emb_2d_3d/'\n",
    "\n",
    "    def get_classifier(dir_path='', device='cpu'):\n",
    "        '''Rebuild the classifier from its pickled args and load the saved weights.'''\n",
    "        # SECURITY: pickle.load can execute arbitrary code; only use trusted checkpoint directories.\n",
    "        with open(osp.join(dir_path, 'args.pickle'), 'rb') as f:\n",
    "            args_classifier = pickle.load(f)\n",
    "        args_classifier.device = device\n",
    "        args_classifier.model_name = 'egnn'\n",
    "        classifier = main_qm9_prop.get_model(args_classifier)\n",
    "        classifier_state_dict = torch.load(osp.join(dir_path, 'best_checkpoint.npy'), map_location=torch.device('cpu'))\n",
    "        classifier.load_state_dict(classifier_state_dict)\n",
    "        # The model is only used for inference below, so switch it to eval mode.\n",
    "        classifier.eval()\n",
    "        return classifier\n",
    "\n",
    "    model = get_classifier(classifier_dir)\n",
    "\n",
    "    # Normalization statistics come from the validation split. Each array is read once\n",
    "    # and the archive is closed right away (indexing an npz re-reads the array every time).\n",
    "    with np.load(osp.join(data_dir, 'valid.npz')) as valid_data:\n",
    "        valid_values = valid_data[condition]\n",
    "    cond_mean = valid_values.mean()\n",
    "    # cond_mad is 10x the mean absolute deviation; the factors of 10 further down undo this.\n",
    "    cond_mad = np.abs(valid_values - cond_mean).mean() * 10\n",
    "\n",
    "    with np.load(osp.join(data_dir, 'train.npz')) as train_data:\n",
    "        train_values = train_data[condition]\n",
    "        num_data = train_data['emb_2d'].shape[0]\n",
    "\n",
    "    # RandomState(42) yields the same permutation as np.random.seed(42) followed by\n",
    "    # np.random.permutation, without reseeding the global generator.\n",
    "    idx_perm = np.random.RandomState(42).permutation(num_data)\n",
    "    idx_train = idx_perm[num_data // 2:]  # second half of the shuffled indices\n",
    "    condition_train = torch.tensor((train_values - cond_mean) / cond_mad)[idx_train]\n",
    "    cond_max, cond_min = condition_train.max().item(), condition_train.min().item()\n",
    "    print(cond_max, cond_min, cond_mean, cond_mad)\n",
    "\n",
    "    train_subset = train_values[idx_train]\n",
    "    v_max, v_min = train_subset.max().item(), train_subset.min().item()\n",
    "    # Clamp bounds expressed in the classifier's output scale: (value - mean) / (cond_mad / 10).\n",
    "    pred_max = float((v_max - cond_mean) / cond_mad * 10)\n",
    "    pred_min = float((v_min - cond_mean) / cond_mad * 10)\n",
    "    print(pred_max, pred_min)\n",
    "    scale = v_max - v_min\n",
    "    print(v_max, v_min)\n",
    "    # scale / 3 is the expected absolute difference of two independent uniform draws\n",
    "    # on an interval of width scale.\n",
    "    print('random baseline', scale / 3)\n",
    "\n",
    "    mol_list = torch.load(osp.join(log_dir, 'sample_conformer.pt'))\n",
    "    condition_list = torch.load(osp.join(log_dir, 'condition.pt'))\n",
    "\n",
    "    pred_list = []\n",
    "    label_list = []\n",
    "\n",
    "    # zip stops at the shorter input, so conditions stay aligned with the truncated molecule list.\n",
    "    for mol, cond in zip(tqdm(mol_list[:max_samples]), condition_list):\n",
    "        atom_list = [atom.GetSymbol() for atom in mol.GetAtoms()]\n",
    "        # Skip empty or oversized molecules and molecules with atoms outside the QM9 set.\n",
    "        if not atom_list or len(atom_list) > max_atoms:\n",
    "            continue\n",
    "        if set(atom_list).difference(qm9_atom_list):\n",
    "            continue\n",
    "\n",
    "        # Featurization: one-hot atom types (n_atoms, 5) and conformer coordinates.\n",
    "        n_nodes = len(atom_list)\n",
    "        atom_types = torch.tensor([atom_encoder[atom] for atom in atom_list], dtype=torch.int64)\n",
    "        nodes = torch.nn.functional.one_hot(atom_types, num_classes=5).float()\n",
    "        atom_positions = torch.tensor(mol.GetConformer().GetPositions(), dtype=torch.float32)\n",
    "\n",
    "        _, edge_index = utils_yy.construct_complete_graph(n_nodes, return_index=True, add_self_loop=False)\n",
    "        edges = [edge_index[0], edge_index[1]]\n",
    "\n",
    "        # Single unpadded molecule, so every node and edge is unmasked.\n",
    "        atom_mask = torch.ones((n_nodes, 1))\n",
    "        edge_mask = torch.ones((edge_index.shape[1], 1))\n",
    "\n",
    "        with torch.no_grad():\n",
    "            pred = model(h0=nodes, x=atom_positions, edges=edges, edge_attr=None, node_mask=atom_mask, edge_mask=edge_mask, n_nodes=n_nodes)\n",
    "        # Keep predictions inside the value range seen in the training subset.\n",
    "        pred = torch.clamp(pred, min=pred_min, max=pred_max)\n",
    "\n",
    "        # Map the prediction back to the property's original units.\n",
    "        pred_list.append(pred.item() * cond_mad / 10 + cond_mean)\n",
    "        label_list.append(cond.item())\n",
    "\n",
    "    print(len(pred_list))\n",
    "    # With nothing scored, numpy would warn about the mean of an empty slice; print nan directly.\n",
    "    mae = np.abs(np.array(pred_list) - np.array(label_list)).mean() if pred_list else float('nan')\n",
    "    print(mae)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "435fa8c6-531c-49a7-9e7f-7c192a8378f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run evaluate() for the 'alpha' property (prints statistics, sample count and MAE).\n",
    "# NOTE(review): the same call appears again in a later cell that holds recorded output.\n",
    "evaluate('alpha')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63d26818-6403-4dd6-916b-b06dd560dcb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run evaluate() for the 'gap' property (prints statistics, sample count and MAE).\n",
    "# NOTE(review): the same call appears again in a later cell that holds recorded output.\n",
    "evaluate('gap')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b6c1a889",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9102288484573364 -1.0507861375808716 75.37342 62.727723121643066\n",
      "9.102288057650055 -10.507861181150194\n",
      "132.47000122070312 9.460000038146973\n",
      "random baseline 41.003333727518715\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [13:23<00:00, 124.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "71816\n",
      "15.562691058177412\n"
     ]
    }
   ],
   "source": [
    "# Run evaluate() for the 'alpha' property (prints statistics, sample count and MAE).\n",
    "evaluate('alpha')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c207bead",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.947828471660614 -0.5386840105056763 0.25221726 0.390242263674736\n",
      "9.478284826116898 -5.386839999595246\n",
      "0.6220999956130981 0.041999999433755875\n",
      "random baseline 0.1933666653931141\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [17:05<00:00, 97.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "77275\n",
      "0.10714337986933754\n"
     ]
    }
   ],
   "source": [
    "# Run evaluate() for the 'gap' property (prints statistics, sample count and MAE).\n",
    "evaluate('gap')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1909b388",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7526819705963135 -1.1657202243804932 -0.24028876 0.1615406945347786\n",
      "7.526819599545699 -11.657201893308272\n",
      "-0.11869999766349792 -0.428600013256073\n",
      "random baseline 0.10330000519752502\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [13:17<00:00, 125.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "68286\n",
      "0.05462335903206425\n"
     ]
    }
   ],
   "source": [
    "# Run evaluate() for the 'homo' property (prints statistics, sample count and MAE).\n",
    "evaluate('homo')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "58d04867",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.47794172167778015 -0.4841439723968506 0.011928127 0.37990380078554153\n",
      "4.77941704545956 -4.841439664678198\n",
      "0.19349999725818634 -0.1720000058412552\n",
      "random baseline 0.12183333436648051\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [16:13<00:00, 102.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "70039\n",
      "0.06308332566491348\n"
     ]
    }
   ],
   "source": [
    "# Run evaluate() for the 'lumo' property (prints statistics, sample count and MAE).\n",
    "evaluate('lumo')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c8580f09",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.9160064458847046 -0.22752515971660614 2.6750875 11.757326126098633\n",
      "19.160064322765635 -2.2752515523038177\n",
      "25.202199935913086 0.0\n",
      "random baseline 8.40073331197103\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [10:03<00:00, 165.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "49756\n",
      "6.3382805550722034\n"
     ]
    }
   ],
   "source": [
    "# Run evaluate() for the 'mu' property (prints statistics, sample count and MAE).\n",
    "evaluate('mu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4d14c170",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4776511788368225 -0.7890406250953674 31.620028 32.11751937866211\n",
      "4.776511788526875 -7.890406281196496\n",
      "46.96099853515625 6.2779998779296875\n",
      "random baseline 13.560999552408854\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [15:48<00:00, 105.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "73815\n",
      "13.667542427023548\n"
     ]
    }
   ],
   "source": [
    "# Run evaluate() for the 'Cv' property (prints statistics, sample count and MAE).\n",
    "evaluate('Cv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a9ba8fcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_ood(condition='alpha', log_dir=None, classifier_dir=None, data_dir=None,\n",
    "                 sample_start=-100000, sample_stop=-90000):\n",
    "    '''Score a slice of property-conditioned samples with a pretrained classifier.\n",
    "\n",
    "    Variant of the evaluation that widens the upper clamp bound and the\n",
    "    baseline range by a factor of 1.5, applies no atom-count limit, and runs an\n",
    "    MMFF optimization on each molecule before reading its coordinates. Prints\n",
    "    normalization statistics, a random baseline, the number of scored\n",
    "    molecules and the mean absolute error (MAE). Nothing is returned.\n",
    "\n",
    "    Args:\n",
    "        condition: Key of the property array in ``valid.npz`` / ``train.npz``;\n",
    "            also the suffix of the default log and classifier directories.\n",
    "        log_dir: Directory with ``sample_conformer.pt`` and ``condition.pt``.\n",
    "            ``None`` keeps the original hard-coded location.\n",
    "        classifier_dir: Directory with ``args.pickle`` and\n",
    "            ``best_checkpoint.npy``. ``None`` keeps the original location.\n",
    "        data_dir: Directory with ``valid.npz`` and ``train.npz``. ``None``\n",
    "            keeps the original location.\n",
    "        sample_start: Start index of the slice taken from both the molecule\n",
    "            list and the condition list.\n",
    "        sample_stop: Stop index (exclusive) of that slice.\n",
    "    '''\n",
    "    qm9_atom_list = ['H', 'C', 'O', 'N', 'F']\n",
    "    atom_encoder = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}\n",
    "    range_factor = 1.5  # widening applied to the upper clamp bound and the baseline range\n",
    "\n",
    "    # Defaults reproduce the original hard-coded (machine-specific) locations.\n",
    "    if log_dir is None:\n",
    "        log_dir = 'logs/job8_latent_ddpm_qm9_condition_' + condition\n",
    "    if classifier_dir is None:\n",
    "        classifier_dir = '/scratch/user/yuning.you/project/graph_latent_diffusion/e3_diffusion_for_molecules_official/qm9/property_prediction/checkpoints/QM9/Property_Classifiers/exp_class_' + condition\n",
    "    if data_dir is None:\n",
    "        data_dir = '/scratch/user/yuning.you/project/graph_latent_diffusion/e3_diffusion_for_molecules/qm9/latent_diffusion/emb_2d_3d_4layer_new/'\n",
    "\n",
    "    def get_classifier(dir_path='', device='cpu'):\n",
    "        '''Rebuild the classifier from its pickled args and load the saved weights.'''\n",
    "        # SECURITY: pickle.load can execute arbitrary code; only use trusted checkpoint directories.\n",
    "        with open(osp.join(dir_path, 'args.pickle'), 'rb') as f:\n",
    "            args_classifier = pickle.load(f)\n",
    "        args_classifier.device = device\n",
    "        args_classifier.model_name = 'egnn'\n",
    "        classifier = main_qm9_prop.get_model(args_classifier)\n",
    "        classifier_state_dict = torch.load(osp.join(dir_path, 'best_checkpoint.npy'), map_location=torch.device('cpu'))\n",
    "        classifier.load_state_dict(classifier_state_dict)\n",
    "        # The model is only used for inference below, so switch it to eval mode.\n",
    "        classifier.eval()\n",
    "        return classifier\n",
    "\n",
    "    model = get_classifier(classifier_dir)\n",
    "\n",
    "    # Normalization statistics come from the validation split. Each array is read once\n",
    "    # and the archive is closed right away (indexing an npz re-reads the array every time).\n",
    "    with np.load(osp.join(data_dir, 'valid.npz')) as valid_data:\n",
    "        valid_values = valid_data[condition]\n",
    "    cond_mean = valid_values.mean()\n",
    "    # cond_mad is 10x the mean absolute deviation; the factors of 10 further down undo this.\n",
    "    cond_mad = np.abs(valid_values - cond_mean).mean() * 10\n",
    "\n",
    "    with np.load(osp.join(data_dir, 'train.npz')) as train_data:\n",
    "        train_values = train_data[condition]\n",
    "        num_data = train_data['emb_2d'].shape[0]\n",
    "\n",
    "    # RandomState(42) yields the same permutation as np.random.seed(42) followed by\n",
    "    # np.random.permutation, without reseeding the global generator.\n",
    "    idx_perm = np.random.RandomState(42).permutation(num_data)\n",
    "    idx_train = idx_perm[num_data // 2:]  # second half of the shuffled indices\n",
    "    condition_train = torch.tensor((train_values - cond_mean) / cond_mad)[idx_train]\n",
    "    cond_max, cond_min = condition_train.max().item(), condition_train.min().item()\n",
    "    print(cond_max, cond_min, cond_mean, cond_mad)\n",
    "\n",
    "    train_subset = train_values[idx_train]\n",
    "    v_max, v_min = train_subset.max().item(), train_subset.min().item()\n",
    "    # Clamp bounds in the classifier's output scale; only the upper bound is widened.\n",
    "    # NOTE(review): multiplying by 1.5 widens the bound only when the normalized maximum\n",
    "    # is positive - confirm this holds for every property used.\n",
    "    pred_max = float((v_max - cond_mean) / cond_mad * 10 * range_factor)\n",
    "    pred_min = float((v_min - cond_mean) / cond_mad * 10)\n",
    "    print(pred_max, pred_min)\n",
    "    scale = (v_max - v_min) * range_factor\n",
    "    # Equals E|X - Y| for X uniform on [0, 1] and Y uniform on [0, 0.25]\n",
    "    # (integral of y**2 - y + 1/2 over [0, 0.25], divided by 0.25).\n",
    "    # NOTE(review): presumably models targets drawn from an end quarter of the range - confirm.\n",
    "    baseline_factor = (0.25 ** 3 / 3 - 0.25 ** 2 / 2 + 0.25 / 2) / 0.25\n",
    "    print(v_max, v_min)\n",
    "    print('random baseline', scale * baseline_factor)\n",
    "\n",
    "    mol_list = torch.load(osp.join(log_dir, 'sample_conformer.pt'))\n",
    "    condition_list = torch.load(osp.join(log_dir, 'condition.pt'))\n",
    "\n",
    "    pred_list = []\n",
    "    label_list = []\n",
    "\n",
    "    # The same slice is taken from both lists so molecules stay paired with their targets.\n",
    "    mol_slice = mol_list[sample_start:sample_stop]\n",
    "    cond_slice = condition_list[sample_start:sample_stop]\n",
    "    for mol, cond in zip(tqdm(mol_slice), cond_slice):\n",
    "        atom_list = [atom.GetSymbol() for atom in mol.GetAtoms()]\n",
    "        # Skip empty molecules and molecules with atoms outside the QM9 set (no size limit here).\n",
    "        if not atom_list:\n",
    "            continue\n",
    "        if set(atom_list).difference(qm9_atom_list):\n",
    "            continue\n",
    "\n",
    "        # Featurization: one-hot atom types (n_atoms, 5).\n",
    "        n_nodes = len(atom_list)\n",
    "        atom_types = torch.tensor([atom_encoder[atom] for atom in atom_list], dtype=torch.int64)\n",
    "        nodes = torch.nn.functional.one_hot(atom_types, num_classes=5).float()\n",
    "\n",
    "        # Optimizes the conformer in place before its coordinates are read.\n",
    "        # NOTE(review): the returned status is ignored, so molecules whose optimization\n",
    "        # failed or did not converge are still scored - confirm this is intended.\n",
    "        AllChem.MMFFOptimizeMolecule(mol)\n",
    "        atom_positions = torch.tensor(mol.GetConformer().GetPositions(), dtype=torch.float32)\n",
    "\n",
    "        _, edge_index = utils_yy.construct_complete_graph(n_nodes, return_index=True, add_self_loop=False)\n",
    "        edges = [edge_index[0], edge_index[1]]\n",
    "\n",
    "        # Single unpadded molecule, so every node and edge is unmasked.\n",
    "        atom_mask = torch.ones((n_nodes, 1))\n",
    "        edge_mask = torch.ones((edge_index.shape[1], 1))\n",
    "\n",
    "        with torch.no_grad():\n",
    "            pred = model(h0=nodes, x=atom_positions, edges=edges, edge_attr=None, node_mask=atom_mask, edge_mask=edge_mask, n_nodes=n_nodes)\n",
    "        # Keep predictions inside the (widened) value range of the training subset.\n",
    "        pred = torch.clamp(pred, min=pred_min, max=pred_max)\n",
    "\n",
    "        # Map the prediction back to the property's original units.\n",
    "        pred_list.append(pred.item() * cond_mad / 10 + cond_mean)\n",
    "        label_list.append(cond.item())\n",
    "\n",
    "    print(len(pred_list))\n",
    "    # With nothing scored, numpy would warn about the mean of an empty slice; print nan directly.\n",
    "    mae = np.abs(np.array(pred_list) - np.array(label_list)).mean() if pred_list else float('nan')\n",
    "    print(mae)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3e722740",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9102288484573364 -1.0507861375808716 75.37342 62.727723121643066\n",
      "13.653432086475082 -10.507861181150194\n",
      "132.47000122070312 9.460000038146973\n",
      "random baseline 73.03718820214272\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [05:53<00:00, 28.28it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9985\n",
      "32.064227763569484\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Run evaluate_ood() for the 'alpha' property (prints statistics, sample count and MAE).\n",
    "evaluate_ood('alpha')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b1d4176b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.947828471660614 -0.5386840105056763 0.25221726 0.390242263674736\n",
      "14.217427239175347 -5.386839999595246\n",
      "0.6220999956130981 0.041999999433755875\n",
      "random baseline 0.3444343727314845\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [01:25<00:00, 117.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9381\n",
      "0.36313442086942527\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Run evaluate_ood() for the 'gap' property (prints statistics, sample count and MAE).\n",
    "evaluate_ood('gap')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2afd3305",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7526819705963135 -1.1657202243804932 -0.24028876 0.1615406945347786\n",
      "11.290229399318548 -11.657201893308272\n",
      "-0.11869999766349792 -0.428600013256073\n",
      "random baseline 0.18400313425809145\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [01:22<00:00, 120.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5936\n",
      "0.10930117979748938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Run evaluate_ood() for the 'homo' property (prints statistics, sample count and MAE).\n",
    "evaluate_ood('homo')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9aeb06a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.47794172167778015 -0.4841439723968506 0.011928127 0.37990380078554153\n",
      "7.16912556818934 -4.841439664678198\n",
      "0.19349999725818634 -0.1720000058412552\n",
      "random baseline 0.2170156268402934\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [09:01<00:00, 18.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9472\n",
      "0.17869909136619652\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Run evaluate_ood() for the 'lumo' property (prints statistics, sample count and MAE).\n",
    "evaluate_ood('lumo')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e6e47cd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.9160064458847046 -0.22752515971660614 2.6750875 11.757326126098633\n",
      "28.740096484148452 -2.2752515523038177\n",
      "25.202199935913086 0.0\n",
      "random baseline 14.963806211948395\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:37<00:00, 264.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1579\n",
      "22.18587002372699\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Run evaluate_ood() for the 'mu' property (prints statistics, sample count and MAE).\n",
    "evaluate_ood('mu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c0cce7c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4776511788368225 -0.7890406250953674 31.620028 32.11751937866211\n",
      "7.164767682790313 -7.890406281196496\n",
      "46.96099853515625 6.2779998779296875\n",
      "random baseline 24.15553045272827\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [04:17<00:00, 38.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9991\n",
      "31.126622606334703\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Run evaluate_ood() for the 'Cv' property (prints statistics, sample count and MAE).\n",
    "evaluate_ood('Cv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2178c79",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
